from torch.utils.data import Dataset
from PIL import Image
import os

class MyDataset(Dataset):
    """Image dataset whose samples all come from a single class folder.

    Every entry in ``<root_dir>/<label_dir>`` is treated as one sample, and
    the folder name ``label_dir`` itself is returned as the label of every
    sample.

    Args:
        root_dir: Directory containing the per-class sub-folders.
        label_dir: Name of the sub-folder to read; also used as the label.
    """

    def __init__(self, root_dir, label_dir):
        self.root_dir = root_dir
        self.label_dir = label_dir
        self.path = os.path.join(self.root_dir, self.label_dir)
        # os.listdir order is arbitrary; sorting makes the index -> file
        # mapping deterministic.
        self.img_path = sorted(os.listdir(self.path))

    def __getitem__(self, index):
        """Return ``(image, label)`` for the sample at *index*.

        The label is always ``self.label_dir``. The image is decoded eagerly
        so the underlying file handle is not held open until the image object
        is garbage-collected.
        """
        img_name = self.img_path[index]
        # Reuse self.path: it is the exact folder img_path was listed from.
        img_item_path = os.path.join(self.path, img_name)
        img = Image.open(img_item_path)
        # Image.open is lazy and keeps the file open; load() reads the pixel
        # data and closes a file Pillow opened itself (multi-frame formats
        # such as GIF are the exception and stay open for seeking).
        img.load()
        label = self.label_dir
        return img, label

    def __len__(self):
        """Return the number of entries found in the class folder."""
        return len(self.img_path)

# One dataset per class folder, both under the same training root
# (the generator yields 'ants' first, then 'bees').
ants_dataset, bees_dataset = (
    MyDataset(root_dir='week1/PyTorch学习/dataset/train', label_dir=class_name)
    for class_name in ('ants', 'bees')
)
# Merge the two class datasets into a single training dataset via `+`.
train_dataset = ants_dataset + bees_dataset

# Preview the first sample of each class; `label` ends up as the bees label.
img_ants, label = ants_dataset[0]
img_ants.show()

img_bees, label = bees_dataset[0]
img_bees.show()
